{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071c7845",
   "metadata": {},
   "source": [
    "下面介绍基于循环神经网络的编码器和解码器的代码实现。首先是作为编码器的循环神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f682c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "# Imported here so the notebook survives Restart & Run All: the next\n",
    "# cell (RNNDecoder) calls F.relu / F.log_softmax, but F was previously\n",
    "# only imported further below in the attention cell.\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class RNNEncoder(nn.Module):\n",
    "    \"\"\"GRU-based encoder: embeds token ids and runs them through a GRU.\"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        # Hidden state size\n",
    "        self.hidden_size = hidden_size\n",
    "        # Vocabulary size\n",
    "        self.vocab_size = vocab_size\n",
    "        # Token embedding layer; embedding dim equals hidden size\n",
    "        self.embedding = nn.Embedding(self.vocab_size,\\\n",
    "            self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs: batch * seq_len (LongTensor of token ids)\n",
    "        # The GRU uses batch_first=True, so inputs need a batch dim\n",
    "        features = self.embedding(inputs)\n",
    "        # output: batch * seq_len * hidden_size (per-step states)\n",
    "        # hidden: 1 * batch * hidden_size (final state)\n",
    "        output, hidden = self.gru(features)\n",
    "        return output, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dc8af74",
   "metadata": {},
   "source": [
    "接下来是作为解码器的另一个循环神经网络的代码实现。<!--我们使用编码器最终的输出用作解码器的初始隐状态，这个输出向量有时称作上下文向量（context vector），它编码了整个源序列的信息。解码器最初的输入词元是“\\<sos\\>”（start-of-string）。解码器的目标为，输入编码器隐状态，一步一步解码出整个目标序列。解码器每一步的输入可以是真实目标序列，也可以是解码器上一步的预测结果。-->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9beee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNDecoder(nn.Module):\n",
    "    \"\"\"GRU-based decoder that generates one token per step.\"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        # Seq2seq does not require the encoder and decoder to share a\n",
    "        # language, so the decoder defines its own embedding layer\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # Maps output hidden states to a distribution over the vocabulary\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # Decode an entire sequence\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # Start decoding from <sos>\n",
    "        # NOTE(review): SOS_token and MAX_LENGTH are defined in later\n",
    "        # notebook cells; run those cells before calling forward()\n",
    "        decoder_input = torch.empty(batch_size, 1,\\\n",
    "            dtype=torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        # If a target sequence is given, it fixes the number of steps;\n",
    "        # otherwise decode up to the maximum length\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # Run seq_length decoding steps\n",
    "        for i in range(seq_length):\n",
    "            # Each step consumes one token and one hidden state\n",
    "            decoder_output, decoder_hidden = self.forward_step(\\\n",
    "                decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: feed the ground-truth token as the\n",
    "                # next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # Pick the most probable token from this step's output\n",
    "                # distribution as the next input\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # detach() cuts the graph so no gradient flows back\n",
    "                # through the sampled input\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # Return None last to keep the interface consistent with\n",
    "        # AttnRNNDecoder (which returns attention weights there)\n",
    "        return decoder_outputs, decoder_hidden, None\n",
    "\n",
    "    # Decode a single step\n",
    "    def forward_step(self, input, hidden):\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38306ed9",
   "metadata": {},
   "source": [
    "下面介绍基于注意力机制的循环神经网络解码器的代码实现。\n",
    "我们使用一个注意力层来计算注意力权重，其输入为解码器的输入和隐状态。\n",
    "这里使用Bahdanau注意力（Bahdanau attention），这是序列到序列模型中应用最广泛的注意力机制，特别是机器翻译任务。该注意力机制使用一个对齐模型（alignment model）来计算编码器和解码器隐状态之间的注意力分数，具体来讲就是一个前馈神经网络。相比于点乘注意力，Bahdanau注意力利用了非线性变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d675aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BahdanauAttention(nn.Module):\n",
    "    \"\"\"Additive (Bahdanau) attention: scores come from a small\n",
    "    feed-forward network over the query and keys.\"\"\"\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        # query: batch * 1 * hidden_size\n",
    "        # keys: batch * seq_length * hidden_size\n",
    "        # Broadcasting expands the length-1 query axis over every key\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        # batch * seq_length * 1 -> batch * 1 * seq_length\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        # context: batch * 1 * hidden_size, the weighted sum of keys\n",
    "        context = torch.bmm(weights, keys)\n",
    "        return context, weights\n",
    "\n",
    "class AttnRNNDecoder(nn.Module):\n",
    "    \"\"\"GRU decoder with Bahdanau attention over encoder outputs.\"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(AttnRNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        # The GRU input concatenates the embedded token and the context\n",
    "        # vector, hence input size 2 * hidden_size\n",
    "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # Maps the attention result to a distribution over the vocabulary\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # Decode an entire sequence\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # Start decoding from <sos>\n",
    "        # NOTE(review): SOS_token and MAX_LENGTH are defined in later\n",
    "        # notebook cells; run those cells before calling forward()\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=\\\n",
    "            torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        # If a target sequence is given, it fixes the number of steps;\n",
    "        # otherwise decode up to the maximum length\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # Run seq_length decoding steps\n",
    "        for i in range(seq_length):\n",
    "            # Each step consumes one token and one hidden state\n",
    "            decoder_output, decoder_hidden, attn_weights = \\\n",
    "                self.forward_step(\n",
    "                    decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: feed the ground-truth token as the\n",
    "                # next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # Pick the most probable token from this step's output\n",
    "                # distribution as the next input\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # detach() cuts the graph so no gradient flows back\n",
    "                # through the sampled input\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        # Same 3-tuple interface as RNNDecoder; the last element here\n",
    "        # carries the attention weights instead of None\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    # Decode a single step\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        embeded =  self.embedding(input)\n",
    "        # The GRU hidden state is 1 * batch * hidden_size, but the\n",
    "        # attention query needs batch * 1 * hidden_size\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embeded, context), dim=2)\n",
    "        # The GRU expects the hidden state as 1 * batch * hidden_size\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden, attn_weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63b031",
   "metadata": {},
   "source": [
    "\n",
    "接下来我们实现基于Transformer的编码器和解码器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9a91cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "# NOTE(review): wildcard import from the local transformer module; it\n",
    "# must provide EmbeddingLayer, TransformerLayer, MultiHeadSelfAttention,\n",
    "# AddNorm and PositionWiseFNN used below — prefer explicit imports\n",
    "from transformer import *\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    \"\"\"Single-layer Transformer encoder (no output projection).\"\"\"\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # Use TransformerLayer directly as the encoder layer; for\n",
    "        # simplicity only one layer is used\n",
    "        self.layer = TransformerLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # Unlike TransformerLM, the encoder needs no output linear layer\n",
    "        \n",
    "    def forward(self, input_ids):\n",
    "        # This forward() handles one sentence at a time; batch support\n",
    "        # would require returning states according to sequence lengths\n",
    "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
    "        seq_len = input_ids.size(1)\n",
    "        assert seq_len <= self.embedding_layer.max_len\n",
    "        \n",
    "        # 1 * seq_len position ids and an all-ones attention mask\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
    "        hidden_states = self.layer(input_states, attention_mask)\n",
    "        return hidden_states, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e8cfc9c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
    "    \"\"\"Cross-attention: queries from the target side, keys/values from\n",
    "    the source side. Reuses the projections of MultiHeadSelfAttention.\"\"\"\n",
    "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
    "        \"\"\"\n",
    "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
    "        tgt_mask: batch_size * tgt_seq_len\n",
    "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
    "        src_mask: batch_size * src_seq_len\n",
    "        \"\"\"\n",
    "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
    "        queries = self.transpose_qkv(self.W_q(tgt))\n",
    "        keys = self.transpose_qkv(self.W_k(src))\n",
    "        values = self.transpose_qkv(self.W_v(src))\n",
    "        # Unlike self-attention, build the cross mask between target\n",
    "        # and source positions\n",
    "        # batch_size * tgt_seq_len * src_seq_len\n",
    "        attention_mask = tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)\n",
    "        # Repeat tensor elements to support multiple attention heads\n",
    "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
    "        attention_mask = torch.repeat_interleave(attention_mask,\\\n",
    "            repeats=self.num_heads, dim=0)\n",
    "        # (batch_size * num_heads) * tgt_seq_len * \\\n",
    "        # (hidden_size / num_heads)\n",
    "        output = self.attention(queries, keys, values, attention_mask)\n",
    "        # batch * tgt_seq_len * hidden_size\n",
    "        output_concat = self.transpose_output(output)\n",
    "        return self.W_o(output_concat)\n",
    "\n",
    "# TransformerDecoderLayer adds cross multi-head attention on top of\n",
    "# TransformerLayer\n",
    "class TransformerDecoderLayer(nn.Module):\n",
    "    def __init__(self, hidden_size, num_heads, dropout,\\\n",
    "                 intermediate_size):\n",
    "        super().__init__()\n",
    "        self.self_attention = MultiHeadSelfAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
    "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
    "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
    "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
    "        # Multi-head self-attention over the target sequence\n",
    "        tgt = self.add_norm1(tgt_states, self.self_attention(\\\n",
    "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
    "        # Cross multi-head attention over the encoder states\n",
    "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\\\n",
    "            tgt_mask, src_states, src_mask))\n",
    "        # Position-wise feed-forward network\n",
    "        return self.add_norm3(tgt, self.fnn(tgt))\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    \"\"\"Single-layer Transformer decoder with an output projection.\"\"\"\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "                 dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # For simplicity use only one layer\n",
    "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # Like TransformerLM, the decoder needs an output layer\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
    "        # Only one sentence at a time: 1 * seq_len * hidden_size\n",
    "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
    "        \n",
    "        if tgt_tensor is not None:\n",
    "            # Only one sentence at a time: 1 * seq_len\n",
    "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
    "            seq_len = tgt_tensor.size(1)\n",
    "            assert seq_len <= self.embedding_layer.max_len\n",
    "        else:\n",
    "            seq_len = self.embedding_layer.max_len\n",
    "        \n",
    "        # NOTE(review): SOS_token is defined in a later notebook cell\n",
    "        decoder_input = torch.empty(1, 1, dtype=torch.long).\\\n",
    "            fill_(SOS_token)\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        for i in range(seq_len):\n",
    "            decoder_output = self.forward_step(decoder_input,\\\n",
    "                src_mask, src_states)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            \n",
    "            if tgt_tensor is not None:\n",
    "                # Teacher forcing: append the ground-truth token to the\n",
    "                # next step's input history\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    tgt_tensor[:, i:i+1]), 1)\n",
    "            else:\n",
    "                # Pick the most probable token from this step's output\n",
    "                # distribution as the next input\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # detach() cuts the graph so no gradient flows back\n",
    "                # through the sampled input\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    topi.squeeze(-1).detach()), 1)\n",
    "                \n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # Keep the same return interface as RNNDecoder\n",
    "        return decoder_outputs, None, None\n",
    "        \n",
    "    # Decode one step. The interface differs slightly from RNNDecoder:\n",
    "    # RNNDecoder takes one hidden state and one token and returns a\n",
    "    # distribution plus a hidden state; TransformerDecoder takes the\n",
    "    # whole target-side input history (no hidden state) and returns\n",
    "    # only a distribution\n",
    "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
    "        seq_len = tgt_inputs.size(1)\n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
    "        hidden_states = self.layer(src_states, src_mask, tgt_states,\\\n",
    "            tgt_mask)\n",
    "        # Project only the last position's state to the vocabulary\n",
    "        output = self.output_layer(hidden_states[:, -1:, :])\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e58b811",
   "metadata": {},
   "source": [
    "下面以机器翻译（中-英）为例展示如何训练序列到序列模型。这里使用的是中英文Books数据，其中中文标题来源于第4章所使用的数据集，英文标题是使用已训练好的机器翻译模型从中文标题翻译而得，因此该数据并不保证准确性，仅用于演示。\n",
    "\n",
    "首先需要对源语言和目标语言分别建立索引，并记录词频。\n",
    "\n",
    "<!-- 代码来自：\n",
    "\n",
    "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <span style=\"color:blue;font-size:20px\">BSD-3-Clause license</span>\n",
    "\n",
    "下载链接：\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.en\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.zh\n",
    " -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6b258fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "# Special token ids: sequence start / end\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    \"\"\"Vocabulary for one language: word<->index maps and word counts.\"\"\"\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"<sos>\", 1: \"<eos>\"}\n",
    "        self.n_words = 2  # Count SOS and EOS\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        # Sentences are pre-tokenized with single spaces\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "            \n",
    "    def sent2ids(self, sent):\n",
    "        # Raises KeyError on out-of-vocabulary words\n",
    "        return [self.word2index[word] for word in sent.split(' ')]\n",
    "    \n",
    "    def ids2sent(self, ids):\n",
    "        return ' '.join([self.index2word[idx] for idx in ids])\n",
    "\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "# Files are unicode-encoded; convert to ASCII, lowercase, and adjust\n",
    "# punctuation\n",
    "def unicodeToAscii(s):\n",
    "    # Strip combining marks (accents) after NFD normalization\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    # Insert a space before punctuation so it becomes its own token\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
    "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01ae1e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the data: two files whose i-th lines form a parallel\n",
    "# source/target sentence pair\n",
    "def readLangs(lang1, lang2):\n",
    "    # Read each file and split into lines\n",
    "    lines1 = open(f'{lang1}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    lines2 = open(f'{lang2}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    print(len(lines1), len(lines2))\n",
    "    \n",
    "    # Normalize\n",
    "    lines1 = [normalizeString(s) for s in lines1]\n",
    "    lines2 = [normalizeString(s) for s in lines2]\n",
    "    # Chinese is segmented into single characters separated by spaces\n",
    "    if lang1 == 'zh':\n",
    "        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]\n",
    "    if lang2 == 'zh':\n",
    "        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]\n",
    "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
    "\n",
    "    input_lang = Lang(lang1)\n",
    "    output_lang = Lang(lang2)\n",
    "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6a561e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['i n d e s i g n c c 核 心 应 用 案 例 教 程 （ 全 彩 慕 课 版 ）', 'indesign cc core application case studies curriculum (purpose full-coloured version)']\n"
     ]
    }
   ],
   "source": [
    "# To speed up training, filter out sentences that are too long\n",
    "MAX_LENGTH = 30\n",
    "\n",
    "def filterPair(p):\n",
    "    # Keep a pair only if both sides have fewer than MAX_LENGTH tokens\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]\n",
    "\n",
    "def prepareData(lang1, lang2):\n",
    "    # Read, filter, and build the vocabularies for both languages\n",
    "    input_lang, output_lang, pairs = readLangs(lang1, lang2)\n",
    "    print(f\"读取 {len(pairs)} 对序列\")\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(f\"过滤后剩余 {len(pairs)} 对序列\")\n",
    "    print(\"统计词数\")\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfa6874b",
   "metadata": {},
   "source": [
    "为了便于训练，对每一对源-目标句子需要准备一个源张量（源句子的词元索引）和一个目标张量（目标句子的词元索引）。在两个句子的末尾会添加“\\<eos\\>”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "92baa7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
    "def get_train_data():\n",
    "    # Convert each sentence pair into two index lists (source, target)\n",
    "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "    train_data = []\n",
    "    for idx, (src_sent, tgt_sent) in enumerate(pairs):\n",
    "        src_ids = input_lang.sent2ids(src_sent)\n",
    "        tgt_ids = output_lang.sent2ids(tgt_sent)\n",
    "        # Append <eos> to both sides\n",
    "        src_ids.append(EOS_token)\n",
    "        tgt_ids.append(EOS_token)\n",
    "        train_data.append([src_ids, tgt_ids])\n",
    "    return input_lang, output_lang, train_data\n",
    "        \n",
    "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411c9ed3",
   "metadata": {},
   "source": [
    "<!--训练时，我们使用编码器编码源句子，并且保留每一步的输出向量和最终的隐状态。然后解码器以“\\<sos\\>”作为第一个输入，以及编码器的最终隐状态作为它的第一个隐状态。-->\n",
    "接下来是训练代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8b85d1ba",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, loss=0.0054: 100%|█| 20/20 [19:41<00:00, 59.09s/it\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAASPBJREFUeJzt3Xd8U1X/B/BPkrbpXnRDKYUCpZQ9CwgoewmKj4qogBtxoIKMR5aoBSdDHvRRH0B/Kg6Gyt5775ZRKBQo0AF0052c3x+loWmTNk2T3iT9vF+vvl7NveeefG9T7Mdzzz1XJoQQICIiIrJAcqkLICIiItKHQYWIiIgsFoMKERERWSwGFSIiIrJYDCpERERksRhUiIiIyGIxqBAREZHFspO6gJpQq9W4desW3NzcIJPJpC6HiIiIDCCEQHZ2NoKCgiCXVz5mYtVB5datWwgODpa6DCIiIjJCYmIiGjRoUGkbqw4qbm5uAEpO1N3dXeJqiIiIyBBZWVkIDg7W/B2vjFUHldLLPe7u7gwqREREVsaQaRucTEtEREQWi0GFiIiILBaDChEREVksBhUiIiKyWAwqREREZLEYVIiIiMhiMagQERGRxWJQISIiIovFoEJEREQWi0GFiIiILBaDChEREVksBhUiIiKyWAwqeuQVqqQugYiIqM5jUNHh3K0stJi5CdPXxEhdChERUZ3GoKLDkp3xAIBfDl+XuBIiIqK6jUFFB7lcJnUJREREBAYVnewVDCpERESWgEFFh0fbBAEAWga5S1wJERFR3SZpUFGpVJgxYwZCQ0Ph5OSEJk2aYO7cuRBCSFkW7OQlPxaVWto6iIiI6jo7Kd98/vz5WLp0KVasWIGWLVvi2LFjGDduHDw8PPDWW29JVpejfUlQyeUtykRERJKSNKgcOHAAw4cPx5AhQwAAjRo1wq+//oojR45IWRbcHO0BADkFxZLWQUREVNdJeumnW7du2L59Oy5evAgAOH36NPbt24dBgwbpbF9QUICsrCytL3NwUSoAMKgQERFJTdIRlalTpyIrKwvh4eFQKBRQqVT4+OOPMXr0aJ3to6OjMWfOHLPX5WRfElQKi9VQqwVvVyYiIpKIpCMqv//+O37++Wf88ssvOHHiBFasWIHPP/8cK1as0Nl+2rRpyMzM1HwlJiaapS47xYMfS5FabZb3ICIioqpJOqIyefJkTJ06FU8//TQAoFWrVrh27Rqio6MxZsyYCu2VSiWUSqXZ6yq7jkqRSkAp6U+JiIio7pJ0RCU3NxdyuXYJCoUCaolHMezLjKgUqziiQkREJBVJxwqGDRuGjz/+GA0bNkTLli1x8uRJfPnll3jhhRekLAt2cu0RFSIiIpKGpEFl8eLFmDFjBl5//XWkpqYiKCgIr776KmbOnCllWZDJZLCTy1CsFijmHBUiIiLJSBpU3NzcsGDBAixYsEDKMnSyU9wPKhxRISIikgyf9aOH/f25M0Wco0JERCQZBhU97O7f+cM5KkRERNJhUNFDwQcTEhERSY5BRY/SG38EGFSIiIikwqCih1xWklQEcwoREZFkGFT0KB1RUTOpEBERSYZBRQ/Z/REVTlEhIiKSDoOKHjKOqBAREUmOQUUPzlEhIiKSHoOKHpq7fphUiIiIJMOgooecc1SIiIgkx6CiB+eoEBERSY9BRY8HIyoMKkRERFJhUNFDppmjIm0dREREdRmDih6864eIiEh6DCp6yHjph4iISHIMKnpwCX0iIiLpMajowUs/RERE0mNQ0YMjKkRERNJjUNFDxhEVIiIiyTGo6MEF34iIiKTHoKIHl9AnIiKSHoOKHnwoIRERkfQYVPSQcUSFiIhIcgwqevCuHyIiIukxqOihWUdF4jqIiIjqMgYVPWSco0JERCQ5BhU95HzWDxERkeQYVPTQTKZVS1wIERFRHcagogcn0xIREUmPQUUPTqYlIiKSHoOKHlzwjYiISHoMKnpxwTciIiKp
MajowTkqRERE0mNQ0YMPJSQiIpIeg4oe8vs/Gc5RISIikg6Dih6l66gwpxAREUmHQUUPrkxLREQkPQYVPe7PpeUcFSIiIgkxqOjBdVSIiIikx6CiBy/9EBERSY9BRQ9OpiUiIpIeg4oeDxZ8k7YOIiKiuoxBRQ9e+iEiIpIeg4oeMk6mJSIikhyDih4yLqFPREQkOQYVPfhQQiIiIukxqOgh510/REREkmNQ0YMLvhEREUmPQUWP0jkqcSnZyMovkrgaIiKiuolBRY/Su342n03BQ/N3SlsMERFRHcWgokfpHBUAyMzjiAoREZEUGFT0kMuqbkNERETmxaCiR9kRFSIiIpIGg4oeMgYVIiIiyTGo6MFLP0RERNJjUNGDAypERETSY1DRo/wcFa6lQkREVPsYVPQoP0dFpeIKtURERLWNQUWfckvnX0/LlagQIiKiuotBRY9FO+K1Xn+2OU6iSoiIiOouBhUDqflwQiIiolrHoGKg5Mx8qUsgIiKqcxhU9Aj1cdF6feXOPYkqISIiqrsYVPRIYDAhIiKSHINKNTz3w2FcSsmWugwiIqI6g0FFjxFtgyps23vpDl7+8ZgE1RAREdVNDCp6+LopdW5P4qRaIiKiWsOgokevZn46t/MZQERERLWHQUWPlkHuOrfLZTLkFapquRoiIqK6iUFFDy8XB53bcwtVaDFzE67czqnlioiIiOoeBhUj/XjwmtQlEBER2TzJg8rNmzfx7LPPol69enByckKrVq1w7Jhl3FnzwZAWUpdARERUp0kaVNLT09G9e3fY29tj48aNOHfuHL744gt4eXlJWZbGSw81RniAm859yw9cheDzf4iIiMzKTso3nz9/PoKDg7Fs2TLNttDQUL3tCwoKUFBQoHmdlZVl1voAwNlBoXff2VtZ+OvUTYzrHoogTyez10JERFTXSDqi8vfff6Njx47417/+BT8/P7Rr1w7fffed3vbR0dHw8PDQfAUHB5u9xrkjIvXuG7p4H77bm4Cxy46YvQ4iIqK6SNKgcuXKFSxduhRNmzbF5s2bMX78eLz11ltYsWKFzvbTpk1DZmam5isxMdHsNbYM8qiyzcUU3gFERERkDpJe+lGr1ejYsSM++eQTAEC7du0QGxuLb775BmPGjKnQXqlUQqnUvWIsERER2R5JR1QCAwMRERGhta1Fixa4fv26RBURERGRJZE0qHTv3h1xcXFa2y5evIiQkBCJKtJNzmXziYiIJCFpUHnnnXdw6NAhfPLJJ4iPj8cvv/yC//73v5gwYYKUZVXwfy92kboEIiKiOknSoNKpUyesWbMGv/76KyIjIzF37lwsWLAAo0ePlrKsCvQtp09ERETmJelkWgAYOnQohg4dKnUZlVIYcO3nUko2mvrrXhyOiIiIjCP5EvrWQC6rOqj0+2pPLVRCRERUtzCoGMCQERUAOpfUz8wrQpFKbeqSiIiI6gQGFQMoDBhRAYDU7AKt13dyCtBmzhb0+3K3OcoiIiKyeQwqBpAb+FOSAbhXUIyjV9OgVgvsu3QHAHD1bq75iiMiIrJhkk+mtQaGXvqBDHhpxTEcvHIX8x5vBadKHmhIREREVeOIigEMvfQjgwwHr9wFAPywLwG3MvK19uuaw0JERET6cUTFAHIDR1TK55lbGXma7x/5fBfkchnCfF3xzXMdTFkeERGRzeKIigEMH1F54FJqDjqFemteX7lzD/GpOdh0NtnE1REREdkuBhUDGJhTsO18itbrf07fMkM1REREdQeDigE8nOwxrE1Qle2mrIrRer31XIqelkRERGQIBhUDyGQyLB7VDm6OnNJDRERUmxhUqkFpZ5ofF+/+ISIiMgyDSjU4KEwVVEzSDRERkc1jUKkGBxONqKiZVIiIiAzCoFINHRt5V93IAGrmFCIiIoMwqFRDkKeTSfrhiAoREZFhGFSqoWWQu0n6UXFIhYiIyCAMKtXQP8If80e2wro3e9Son8MJd01UERERkW1jUKkGmUyG
pzo1RGR9jxr188LyYyaqiIiIyLYxqBAREZHFYlAx0rNdG0pdAhERkc1jUDHSEx2CpS6BiIjI5jGoGMlObuAjlYmIiMhoDCpGUjCoEBERmR2DipEYVIiIiMyPQcVIDCpERETmx6BiJM5RISIiMj8GFSPJZQ+CSrcm9SSshIiIyHYxqBip7KWfAHdHCSshIiKyXQwqRiozoIKBkQFo5u9areMLi9UmroiIiMj2MKiYgIvSDlve6VWtY3IKis1UDRERke1gUDGSDA+GVMrOVzH8eCIiIqoKg4oJhPlV77IPADjY8UdPRERUFf61NAEHRcmPsToDK8JMtRAREdkSBhUjlQ0l4n7sENVIH6I6jYmIiOooBhUjaQWV+5mjaTUuAd0rUJm4IiIiItvDoGIkX1clejf3xSPhfvB0tgcAfD+mo8HHx6fmmKs0IiIim2EndQHWSiaTYfm4zlrbQuq5IMDdEclZ+VUeLzhLhYiIqEocUTExQyfUcooKERFR1RhUiIiIyGIxqJjYrGERBrXz5/OBiIiIqsSgYmIDIwNxelZ/7J7cu9J2XEKfiIioagwqZuDhZI+Qei7YNam3ZtvYbo202nyz+3LtFkVERGSFGFTMqJGPi+Z7hVx7lm2BCZ+eHHszE0MW7cXeS7dN1icREZElYFCpJeVyCopVpgsqY5cdxdlbWXjuhyMm65OIiMgSMKjUEnm5pFKsNt39yRm5hSbri4iIyJIwqNSSzo28tV6rqggqQgjsj7+DFAMWjyMiIrJVXJnWzPa+/zAuJGfjkXA/re3Hr6VDpRaauSv7Lt1BgIcSYX5uAIA1J2/i3d9PAwCuzhtSu0UTERFZCI6omFmwtzP6RfhDpmPJ2n9O3wIAXEzJxrM/HEbfL/cgM7cIADQhhYiIqC5jUJHQrcw8AMCeiw/u1mnz4RZe7iEiIrqPQaUWzR0RqfW69Hk/X229qLV93LKj1eqXjw0iIiJbxaBSi57rGoL2DT01r386eA0AkFuk0mp3LikLTf1ca7M0IiIii8SgUsv83B484yc5Kx/FKrXOJym3DfbUfK9SCwg+bpmIiOog3vVTy+TlouHKo4k625Wde9tk+gYAvPuHiIjqHo6o1DIZtO/++WBtrAn6JCIisk0MKrXs7K1Mg9qdS8qqsO1mRp6pyyEiIrJoDCq17OrdXIPapWYVVNhWaMIHGRIREVkDBhULlZpdMajcSM9FXqEKCXfuSVARERFR7eNkWisyfU0M5DIZrt3NxarxUegQ4l31QURERFaMIypWJDEtD9fuXzpadyZJ4mqIiIjMj0HFSh2Iv6v5niusEBGRrTIqqKxYsQLr16/XvH7//ffh6emJbt264dq1ayYrjvSLS8nG7L/PSl0GERGRWRkVVD755BM4OTkBAA4ePIglS5bg008/hY+PD9555x2TFmhrBkUGmKyv5Qeu4tpd7Ym1R6+mmax/IiIiqRkVVBITExEWFgYAWLt2LUaOHIlXXnkF0dHR2Lt3r0kLpModvZoOlfrBxZ9/fXNQwmqIiIhMy6ig4urqirt3S+ZIbNmyBf369QMAODo6Ii+Pi5LVpkl/nJa6BCIiIrMx6vbkfv364aWXXkK7du1w8eJFDB48GABw9uxZNGrUyJT12Rw+W5CIiMhwRo2oLFmyBFFRUbh9+zZWrVqFevXqAQCOHz+OUaNGmbRAW7PpbLLUJRAREVkNo0ZUPD098fXXX1fYPmfOnBoXRERERFTKqBGVTZs2Yd++fZrXS5YsQdu2bfHMM88gPT3dZMXZoiGtAqUugYiIyGoYFVQmT56MrKySp/vGxMTgvffew+DBg5GQkIB3333XpAXamj0Xb0tdAhERkdUw6tJPQkICIiIiAACrVq3C0KFD8cknn+DEiROaibWkW3ZBsdQlEBERWQ2jRlQcHByQm1vyzJlt27ahf//+AABvb2/NSEt1zZs3DzKZDBMnTjTqeCIiIrI9Ro2o9OjRA++++y66d++OI0eO4LfffgMA
XLx4EQ0aNKh2f0ePHsW3336L1q1bG1MOERER2SijRlS+/vpr2NnZ4c8//8TSpUtRv359AMDGjRsxcODAavWVk5OD0aNH47vvvoOXl1elbQsKCpCVlaX1RURERLbLqBGVhg0bYt26dRW2f/XVV9Xua8KECRgyZAj69u2Ljz76qNK20dHRvAWaiIioDjEqqACASqXC2rVrcf78eQBAy5Yt8eijj0KhUBjcx8qVK3HixAkcPXrUoPbTpk3TuqsoKysLwcHB1SuciIiIrIZRQSU+Ph6DBw/GzZs30bx5cwAlox3BwcFYv349mjRpUmUfiYmJePvtt7F161Y4Ojoa9L5KpRJKpdKYkomIiMgKGTVH5a233kKTJk2QmJiIEydO4MSJE7h+/TpCQ0Px1ltvGdTH8ePHkZqaivbt28POzg52dnbYvXs3Fi1aBDs7O6hUKmNKIwBqNR8oREREtsGoEZXdu3fj0KFD8Pb21myrV68e5s2bh+7duxvUR58+fRATE6O1bdy4cQgPD8eUKVOqdQnJmjSq54yrd3PN+h7FagEHucys70FERFQbjBpRUSqVyM7OrrA9JycHDg4OBvXh5uaGyMhIrS8XFxfUq1cPkZGRxpRlFaYNbqFz+1MdTTfXRs1HNBMRkY0wKqgMHToUr7zyCg4fPgwhBIQQOHToEF577TU8+uijpq7RpgxoGYADUx+psH3+E6ZbQ4Y5hYiIbIVRl34WLVqEMWPGICoqCvb29gCAoqIiDB8+HAsWLDC6mF27dhl9rDXxdjFs1MlYKiYVIiKyEUYFFU9PT/z111+Ij4/X3J7cokULhIWFmbQ4WyWXmXf+iIqTaYmIyEYYHFSqeiryzp07Nd9/+eWXxldUB5h7nqvgiAoREdkIg4PKyZMnDWonM/NogS1QlEkqEYHuGNmh5PlIXz3VBp9vvoibGXk16p8jKkREZCsMDiplR0yoZsqGucfb18eLPUIBAI+1a4Dm/u4YvGhvjfpnTiEiIlth1F0/ZDp+7tqr8oYHuKFHmA8eb1/f6D556YeIiGyF0c/6IdPoEuqt9Voul+H/XuoCAFh94qZmu51chmIDh0o4okJERLaCIyoS2T/1EWya+BD83at+zlF9TyeE1HPWvH61Z2P4uT145lG7hp5a7bngGxER2QoGFYnU93RCeIC7QW0VchnKRo9pg1vg8PQ+mtflb3c2NqgIIXDt7j1eOiIiIovBoGIF5LKKq82WnZCrKB9U1NpthRA4kpCGzNyiSt9n7rrz6PXZLnyz+0qN6iUiIjIVBhUrYKeQVzrKIS/3KZYfUfn79C08+e1BDFlc+d1E/9ufAACYv+mCcYUSERGZGIOKFegY4oXKLsZUdelnY0wyAOBGeh42xiQhesN5qKuYcfvDvgT8ffqWUfUSERGZCu/6sWBb3+mJf84k4eWHQjFk0T697RTy8kFFe3/ZHDP+5xMAgPYhXhjQMkBvn3PXnQMAPNomqJpVExERmQ6DigVr6u+Gd/u5Aah8gmz5EZWyl4nUaoGNsckVjkm7V2iiKomIiMyHl36sxI10/cvql392UNkRlZOJ6TqPmbY6Bj3m70CRSq1zf1nzNl7AE0sPoKBYZVCtREREpsKgYgN2xt3Wel129CW/SH8QuZGeh1l/n62y/292X8axa+nYEJNkfJFERERGYFCxQaUPJSwsVmP094crbfvL4esG91tQSeghIiIyBwYVG1Q6oLJo+yWDj6nqLiAABi/hT0REZCoMKjao9NLPlTs5Bh+z6WzFCbdERERS410/Nqg0qJS/G0ifj9efQ2ae7lVry95BZMjEWyIiIlNiULFBpVdo9IWP8r7bm1BlXwBw7W5uTcoiIiKqNl76sQH1PZ20XpeOguy9dKfGfZe9g6hFoFuN+yMiIqoOBhUb0LVxPa3X3+013UMFjX0SMxERkSkwqFixdW/2wIs9QjFzWATmj2yl2b75bIrJ3qNsTmFmISKi2sagYiUmD2gOAHisXX3Ntsj6
HpgxNAIeTvZ4qlNDrfZ5haZZRTYj17B5LkRERObAybRWYsLDYXi9dxPIDLyTp8XMTSZ5367R2zXfc0CFiIhqG0dUrIihIcVceOmHiIhqG4MKERERWSwGFTKY4MUfIiKqZQwqREREZLEYVGxIqI+LWfvnHBUiIqptDCo25NmuIVKXQEREZFIMKjbEQWHeu4I4oEJERLWNQcWGSH37MhERkakxqNgQubmDCiepEBFRLWNQsSFyM+cUlZpBhYiIaheDig0pNnOQYE4hIqLaxqBiQ7afN91Tk3VR89IPERHVMgYVG+LjqjRr/wwqRERU2xhUbMjcEZFm7V+lNmv3REREFTCo2BBHe4VZ+zd2REWlFpyIS0RERmFQqQN2Tuptkn7URoQNtVqgyfQNaDJ9A4o5JENERNXEoFIHhPq4wMPJvsb9qIwYUUnLLdR8f+RqWo1rICKiuoVBpY7Y9m4vvfsWjWpnUB/GXL2xlz/4FZOBK+cSEVH1MKjYuIhAdwCAr5sS297thca+D56wPCgyAMvGdcKgyACD+jLm0o+szG9YVn6R5vtilRrHrqahoFhV7T6JiKjuYFCxcT+M7aj5PszPFTve66153SbYEw8394OizNL7la3Cb8xk2oTb9zTfJ6blar7/cutFPPHNQbz7++lq90lERHUHg4oNuzB3IAI9nCps/+yJ1ugX4Y8xUY0AaIeTykZXjJmjkpSZp/k+2NtZ8/1/dl0GAKw/kwTB9VmIiEgPBhUbpu925X91DMZ3z3eEk0PJ/rJPXXZQ6P+VKJ8nojecx4Cv9uBeQbHeY8r2bafnYUTv/3lG7/FERFS3MajYqDYNPIw6ruwTmJ3KBZ3ya6F8u+cK4lKyserEDey7dAetZm/GTwevarVxVdppvp+38YLO9/zj+A2jaiUiItvHoGKjlEYu/iYvM+pxZnZ/rX365qio1ALP/nAY2fnFmPHXWWSUuSXZy9lB831yZr5RNRERUd3FoGKrjJz2UfbqjH25y0D67vopf0Fn2f6rD44pE27KPt257EgLERGRPgwqNsrY5e47hHgBAPzdKz7g0NDJtLmFJXNWdsalYujifZrtxWo18gpV+OnQNeRUMq+FiIioFP+31kY52BmXQf3dHXF6Vn842pccf2JGPwxYsAe3swv0Lvi24uA1rdeleWbcsqNa24tUAi1mbjKqLiIiqps4omJj5o9shUb1nPFRNZ+kPHd4SzzdKRi9m/vBw8keSruSOS7eLg4YExUCQP+ln4Q797Rex97K1JqnQkREZCyOqNiYpzo1xFOdGlb7uOfur6miS+kEW7UQKCxWw15R+VL4h66koe2HW6tdAxERUXkcUaEqld6y/PuxG2j2wUZM+OUECor5JGQiIjI/BhWqkqLcuvobYpK17uAhIiIyFwYVqpKu5/8Ye1cRERFRdTCoUJUUOpa+H7JorwSVEBFRXcOgQlWS6xhSSUzL09GSiIjItBhUqEpyPQ8TJCIiMjcGFapS+cm0REREtYVBharEARUiIpIKgwpV6e49rjJLRETSYFChKvEBgkREJBUGFaoSr/wQEZFUGFSoSrpuTyYiIqoNDCpUpdpYLl9wpVsiItKBQYWqtPLodbO/Bx8dREREujCoUJUycovM/h58dhAREenCoEIWgUGFiIh0YVChKo1oG2SyvpzsFTq3myqnqNQCr/50DIu3XzJNh0REJClJg0p0dDQ6deoENzc3+Pn5YcSIEYiLi5OyJNIhsr6H0cf2bOaL757vqHmt7wai6oyoqNUCV27n6JyAu+NCKjafTcEXWy9Wu1YiIrI8kgaV3bt3Y8KECTh06BC2bt2KoqIi9O/fH/fu3ZOyLCrn6c4NjT52yTPt0C/Cv8p21ZlM++KKo3jki92Yt/FChX15RarqlEdERBbOTso337Rpk9br5cuXw8/PD8ePH0fPnj0rtC8oKEBBQYHmdVZWltlrJMBVafivyfyRrTBlVQwAYFTnYLg52ht0XHVGVHbG3QYA
fLvnCqYNbqG1L59BhYjIpljUHJXMzEwAgLe3t8790dHR8PDw0HwFBwfXZnl12sKn2+rc7miv/Ss0ol19zfe3swvKN9e7yq2xc1SGLd6HgQv2oPdnO3ExJRvv/3nGuI6IiMgiWUxQUavVmDhxIrp3747IyEidbaZNm4bMzEzNV2JiYi1XWXd5Ojvo3D5jaITWa0WZSSj6RlNe7dW4wjZjF3yLuZmJC8nZuHo3F/2/2mNUH0REZLksJqhMmDABsbGxWLlypd42SqUS7u7uWl9UO67f1T1v6IkODTTfzxwaATvFg18pPzelzmOe7FhxJMwcC74VqdSm75SIiGqVRQSVN954A+vWrcPOnTvRoEGDqg+gWtc9zEfndqWdAq/0bIwhrQMxrnsjrX23MvMrtJfJZPBxrRhgSueoFBar8dCnO/DDvoQa1/z7sUQUFjOsEBFZM0mDihACb7zxBtasWYMdO3YgNDRUynKoEo461j/pEloyl2j64BZY8kx7yMrde3wqMV1nXx5O9ljxQmetbaVB5fWfTyAxLQ9z152rcc3/XhOLpbsu17gfIiKSjqRBZcKECfi///s//PLLL3Bzc0NycjKSk5ORl5cnZVmkQ/knKF+YOxC/vRpV6TGJaRU/x9JeejXzRczs/prtpVNUtp1PqVGd5X21jeupEBFZM0mDytKlS5GZmYnevXsjMDBQ8/Xbb79JWRbpUH6hNl0jLOXpusRTlpujPewVJR1Xdnvytbv3EL3hPBLTcqsulIiIbIqk66gYe6cH1b6yH9UbD4cZdMwzXXQsFFcu8JRcLhJQ3Z9N66CQo/D+JNjUrHwMXrQXd3IKAQA/HbqGmNkDql07ERFZL4uYTEuWT+BBUtEZQAxUfh0VO/n9EZX7c14nlAlBb/xyUhNSACC3UIUm0zcY/d5ERGR9GFSo2lwdTTcQV7ruypU7OQCAnXGpmn3HrqWZ7H2IiMg6MaiQyZWuVturmW+FfeXvDFLcn6MydtlRxKdm41RihmafOdZWISIi6yLpHBWyHnbyB5m2/B1A5R2a1gc30vMMeupy2ZVsZ/191vgCiYjIJjGokEF83ZQY3aUh7BXyKh9S6OnsoHfJ/fLu3nswB2V//N0a1UhERLaHQYUM9vFjrWrcR6N6ziaohIiI6grOUaFasfr1bhgUGYCvn2kvdSlERGRFOKJCtaJ9Qy8sfbaD1GUQEZGV4YgKERERWSwGFSIiIrJYDCpERERksRhUiIiIyGIxqBAREZHFYlAh0kEIgSU747WePURERLWPtycT6bDjQio+2xwHALg6b4jE1RAR1V0cUSGb8t3zHWvcR3Z+EV768ViF7YlpuTh2lU90JiKqTRxRIZvSL8Iffm5KpGYXGN3HjLWxEGWe3LwpNhk/H76GvZfuAADG926C9wc0r/AkaCIiMj2OqJBV+N9Yw0dKOjbyqtF7bT6bovX6tf87rgkpALB012VsOZdS/jAiIjIDBhWyeOc/HIhHwv117msZ5F5h2yfVfHhiYbEa7/1+Gn+dugkAyCtSVXlMzI3Mar0HEREZh0GFLJ6Dnf5f019f6Vphm6ezQ6X9FRarsfNCKu4VFAMAfjt6HatO3MDbK08ZXJNczss+RES1gUGFLJ6iXChwVZZMrarn4gB3R3sMbhVQrf7mb7qAccuP4rX/Ow4AuF1uPoubY9VTtxScn0JEVCs4mZYs2vTB4RW2+bopsXNSb02gmDE0Ajcz8jGuWyOD+vxhXwIAYO+lOzh+Lb3C/uz84ir7SM7KM+i9iIioZjiiQpLyd1fq3ffbK13xSs8mFbarhYCvmxKO9goAQKCHE/6a0B0j2tWv9vuPXHoA9wofzElRqUUlrR/49Uhitd+LiIiqjyMqJCmlnULn9n1THkYDL2ed+wwNE4YqHWEBgKmrzpi0byIiqhmOqJCk3u7TVOd2fSEFgNYaJ6b2x/Eb5uuciIiqjUGFJDWyQwPsff/hah2jrmZSyTfgdmMiIrJMDCokuWBv
7dGTF3uEVtq+upd+Jvx8QvN9+Tt8iIjIsjGokEV5vXcTfDCkRaVtqjtFZfuFVM2oysOf7zKyMiIikgKDClmU3s399D5Dp7GPCwCgT7hftfv9YkvJk5BzCqq+9ZiIiCwH7/ohi7B7cm9cuX0PnUO99bb55eWu2HIuGSPbN6iyPwc7OQqL1ZrX3+1NwAtVXFIiIiLLw6BCFiGkngtC6rlU2ibAwxHPRzUyqD+lQjuoAEBU9A5jyyMiIonw0g/ZJKU9f7WJiGwB/2tONknfQnJERGRdGFTIJmXmFZn9PYQZVp4rf7mKiKiuY1Ahm1Qbd/cUqqoXKm6k5+JIQhoA7ZCjVgvcysjDwAV70OyDjTh3K8ukdRIRWTMGFSIjFamqN6LSY/5OPPntQaw6fgMdP9qG/91/xtCEX06g27wduJCcDQAYsWS/yWslIrJWDCpkk5r5u5qsL4Vc97ouRUZepnnvj9O4e68QH647BwDYGJustb+6IzVERLaMtyeTTerauB4upuQY3H5I60AserodmkzfUGFf/MeDUFCsRkZuEZR2crSbuxUAUFSNQNH54206tzeaut7gPoiI6iIGFbJJDorqDRYueaa93n0ymQyO9goEeJTcSeRoL0d+kRoFBo6oJGXmIbUazxjqF+FvcFsiIlvHSz9kk+yqCCo9wnwQEegOAJU+vXlYm6AK2+zv9x13f05JVRZuu2RQu1J+bspqtScismUcUSGbFOZX+RyV78d0hKN9xbVW3u7TFAu3lwSL9wc2x+guIRXaKO3kyAbw0o/HsPDpthjWOgjy+/NYsvOL8Ph/DqBXM198vy8BDzX1gZ+bY7Vqv3r3XrXaExHZMgYVskmPt6uPreeSsflsis79+i4NvdOvGbLzi+GqVOD13mE629iXOfbtlafw0frzGBQZgA+HR2LR9ku4lJqDS6kl82P2XrpT7dqb+btV+5iqCCH0PuyRiMiSMaiQTZLLZfjm2Q6I3ngBqVn5mDWsJQSA7edT8Ei4n2YERJeZwyIq7dtOoX3s7ewC/HjwGmYNa4nv9ibUuHaFiQPF7L/PYldcKv55swfcHO1N2jcRkbkxqJDNkslkmD64hda2f3UMrnG/iWl5Orf3+3J3jfsGgJrmlNTsfPx86Dqe7hyMQA8nLD9wFQCw+sRNjOnWqMb1ERHVJk6mJTKRK3dMM7ekpsuovPbTcSzcfglj/3dUa3ttrNZLRGRqDCpEFqZYXbOkcuJ6BgAgLkX7rqTPNsfVqF8iIikwqBBZmOosJFdd5niQIhGROTGoEFkYQ58hJITAzgupSMnK19um1ezNWq9jbmbWqDYiotrGoEJkYYorGVHJyC2ESl0SZP45k4Rxy4+iyyfbsebkDQDAzL9itdpn52vPS8kv4nOEiMi68K4fIgkEeTjiVqbukZAite4Rlat37qH357vQMcQLf47vhl1xqZp97/x2Gu/8drrK91Xp6ZuIyFJxRIWoFhyZ3geBHo54tmtD/PNGD+yb8ojetvqeyvztnssAgGPX0gEAl1MNf+hiKc5RISJrwxEVolrg5+6IA1MfMWh12LIPMLyRnou/Tt3Cs11CsOr4Tc32P4/fwOkb1Z9vwphCRNaGIypE1bTw6bZGHVc+pLg4lDxr6I2Hw/DRiEjN9lOJGTh+LR33CooxcukBfLY5DtPXxuDtvk01bSb9UfVlHp01GHUUEZF0OKJCVE3D29bHP6eTsO287ucIlde6gYfO7Zsm9sSuuFT8q2MwHO0V+GDtg4mwI5ce0Gq7/kwS1p9JMr7oUkwqRGRlOKJCZISJ90c3Gvu4VNm2f4S/zu3B3s54LqqR5inOr/duYroC9cjILTL7exARmRKDCpERIut74PyHA7FjUu8K+/58LUrr9fMGPl+nUyNvE1RWuU83XdC7TwiBMf87gqf/e1Az6fZeQTF+PnwNqWXWavnvnstoNHU9UrPzsT/+DkYuPYC45Gx93RIR1Qgv/RAZyen+HJOyRndpiI5lAscH
Q1rA3cAnFhfXwq3DafcKdW5/6NMdWg9bTMkqQICHIz7ZcB4/H76Ob3dfwZzhLbHtXAp+PnwdAND54+2a9gMW7MHVeUPMWzwR1UkcUSEykfqeTpgyKBxAyaWhyPruGNW5ocHHuzua7v8bAj0cdW5/7f7lpYzcQgxcsAdf77iE/CJVhSdCT1t9BgCw+WwyAOB6Wi7GLTuqCSnmJITgAxSJSINBhaiGJvZtio4hXtj+Xi/N6MnEvs2w7s2H4KI0PHx0DjXdpZ8/XovCjy90Rrcm9bS2H0lIAwC0/XArLiRn4/MtF1GgY92WnXG30WjqetzJ0T0CYy7nk7IQFb0DkbM2I5bL/RMRGFSIamxi32b4c3w3zaRYYxmyxooh6ns6IdDDCT2b+eKXl7tqXZLZFXe7Qvs2c7aY5H2rK+1eIY5eTdNahG7Qwr1Ivj8fZujifTiVmCFJbURkOThHhciCTBsUjuiN+ie8Vmbh020xpFUgAEAh1x96vt97xaj+K1PJ2+nVfu5WzfcT+zbFxL7NKrQZsWQ/574Q1XEMKkQW5NVeTSoNKh1CvPDxY5G4cvseOod6Y1NsMk4lZmD+yNaVhpOyPlp/3lTlavi7654TY6gF2y5hwbZLJqqGiGwJgwqRFVk1vhsAIDzAHQDwbNcQPNs1RMqSAOifvAuUTI69mJKDUB8X3M4pwIYzSRgYGVCL1RGRNWNQIbIwy8Z2wrjlRzWvz384ECOXHkBUuYmx1RHg7qiZ+2EORSr9t1avO5OEN389qbXt4w2mH9UxRMKde7hXUIyWQe4mmxNERObFoEJkYR4O98OZ2f3R/8s96B7mAycHBTa8/VCN+lzxQmcMWLCnxrUNbR2IdfeX8m8R6I7zSVkAgIJiFfKLVFhz8ia6hHojNbsA8ak5Wo8FMNam2GQo7eXoGOIFt3Jr0ny/9woc7RVITMvFt3uuoFczX4zq3BBdG5dcFnuiQwPsiruNl348hi+fbIN3f3/wjKTP/9UGT3RoUOP6ACA7vwitZm/B4+3q48un2pqkTyIqIRNW/Nz3rKwseHh4IDMzE+7u7lKXQ2RSarWA3JhZqno0mrre6GP/eC1Ks3LurYw8ODsocD4pG6O+O2Sq8gwydVA4XuvVBGq1wKoTNzD5zzM16s/Yibo5BcWY9PtpvNqrMdo19NL62XYM8cJvr0YZPGeIqC6qzt9vBhWiOkRfWOkf4Y85w1vind9OYVz3UPRq5ovwGZsAAH3C/fDD2E46jzt4+W6thxVTqiqoFKnUyC9S4bu9CVi0/RLCA9zww9hO6D5vR6XHLR/XCb2b+1Wrljs5Bbh2NxcdQrzwyYbz+O+eK3BQyHHx40GVHlesUmPJzsvoFlavVh7DUOr7vVew+sRNLB/XCX41nExd24pVavywLwFD2wShvqeT1OXUSQwqRKTTzL9i8ePBawCAX17qgme+PwwA2PDWQ4gI0v43pFILnE/KQotAd72jA7E3MzF08T6T19nQ2xlKOzkupeaYvO+yFjzVFo72CqTnFmLa6hh8+1wHvPrTcUTWL/lZxN7MMqrfJc+0R8sgd7z04zEseKotwvxc9a6zs+NCCq7cvqe5G6vs5wIAbYM9EVnfHVMGhle49AUAa0/exMTfTgHQHbzuFRTDwU6OFQeuwl4hx5gyz55aeeQ6pq6OQczs/jr7rkxp6O0f4Y8ZQyPQwMsJWXnFUChkyMorQpAFB4AXlh/FjgupACr+zHZfvI36no4I83OTorQ6g0GFiAxy5kYG0nOL0KuZr1HHbzmbjFd+Ol6jGs59OAB7L91BeIAbdl+8jcfbN4Cr0q5Gl6osVXN/N2x+pyee/PYgjiSkYckz7THhlxMGHz9zaARe6BEKAMjMK4Kr0g4frI3Fr0dKHm1Q/o9u2r1CrfVqAKBVfQ/E3MzUmm8EAKdn9oeHs3ZYuVdQDLUQFUJMkUqNpv/eWGmtbzwchsfa10cTX1dsik3C7ZxCPFeNO9Tyi1RQ2slRrBZYsjMe
3cN84OJgh8GL9mLV+Ch0CDF+9Ch8xkbkF5WsyHxg6iPILVShgZcTdsXdxmv/V/L7nBA9GDKZDEIIJKblIdjbyegJ2LE3M6G0k6Opf83Dz8oj1+HhZI9B99dMSs3Kh4+r0qSXiWsDgwoR1Yo7OQXo+NE2g9rOGBoBtVpUuONH3+WXpv/eUOndRHVV7+a+eLJjMF7/uSTgvNuvGb7cehEAcHpWf7SZswUNvJywc1LvKsNEWQNbBqBPCz883r4BFHIZVGqBJtM3AAAufjQItzLy0PvzXVj5SlecT8rCnH/OGVW/vUKGmNkDUKwWiJy1GQBwds4AXLl9D/vi72D+pguQyYCq/jLtn/oIrt65h5ZB7nBR2sFeUflC6/+cvoW/Tt3C0mfb4+2VJ7EhJrlCGzelHbLvP2cq/uNBsFOUjETN+vssHmrqg59e7FLpewghkFuowpGENAR6OiI8wB13cwrQ4f6/kVXju8HH1QENvZ2NCj3/+uYAjl5NB1ASLNt8+GBV6VXju6FDiBcAYM4/Z7Fs/1Uc+6AvfFyVWn2o1AJCCKgFcCM9F419XTX7LqZkY/rqGMhkwE8vdsG+S3ew+WwyJvZrZvJLZAwqRFRr/jx+A5P+OF1pm/cHNsfrvcMAABeSs/D3qVsI9XFBr2a+euc3FBar0eyDqv/Qzh3eEs92DUHotA1a2/e+/zAe+nQn7BUyXPyoZJ6HTCbDf/dcxicbjFv9t65ysJOjUMczoSzJkFaBWDK6veb19vMpOJyQhm5N6sHDyR6P/edAtfu8Om+IzpG9FS90xpazyfj58HX0bu6r89EUlXm7T1M80aEB9ly6jdFdtEeZTidmwM3RThMgEtNyMW/jBSRn5eP4tfRK+73yyWDI5TKtmmcOjcDeS7excFQ7ONkrKoRXuQxQC+D75zvipR+P6e3b1CtEW11QWbJkCT777DMkJyejTZs2WLx4MTp37lzlcQwqRJYlI7cQX++Ix8gODdAi0B2ZuUUoKFYZPdmy7H9w+7bwg9JegfVnkhDo4YgBLQPw7yEtNP8nvSEmCa//fAJKOzmCvZ2x9Z2eev+v1RSXlZaN7YRezXxRrBZwsCup4clvDuLI1bQa9011S4tAd3zyWCRm/nUWMTV8GOfbfZpi4XbTr/Jcp4PKb7/9hueffx7ffPMNunTpggULFuCPP/5AXFwc/PwqnzXPoEJk2zJzi3D8ehp6N/Mz6TX4o1fT8K9vDlbYPrR1IGYOi0Dnj7cDANa92QOz/j6L/4xuDz83pSb4FBarNeGkLLVaYNDCvYhLya6wb+qgcMzT83iEVeOjcPxaOno398N3e66gbUNPDG0dhL2XbuOjdefx/ZiONZ60/Ei4n2YC6ePt62P1iZt4oXso/rc/oUb9ku1b83o3tGvoZdI+rSqodOnSBZ06dcLXX38NAFCr1QgODsabb76JqVOnVnosgwoRGUsIgSJVyWiIEEJr9CUlKx+O9gp4OFXvTphSarXA+6vO4NCVu3iheyj6Rfgj2NtZ671lMhl+PHgVoT4ueKhp1ZOZKxsFGtm+AbaeS0ZWfjFWvtIVT//3EAa09IcMMrQJ9sT43k0AlExQdVDItULfrYw8dKvidutSQR6OuJVZssLxqz0bY+qgcBxOSEOYnyvyClU4n5SFHRdS8cljrQAAX269iK93xhvUty4NvZ3xXv9m6NXMFw52ckxfHYNXezVBYbEaw5fsN7rfslyVdnCwk6OpnysOJ6Th4ea+2FnNSzlSmTIwHHmFxVi0w/ifcVXOfTgAzg6mXxvWaoJKYWEhnJ2d8eeff2LEiBGa7WPGjEFGRgb++usvrfYFBQUoKCjQvM7KykJwcDCDChHVCdfv5sLPXan3VmdTKR/chBAQAiYb1Sr9s1N6V01+kRpqIVCkUsPT2cGgPr7fewUXkrPx5/Ebmm2NfV3g46JERl4h8ovUeLJjA3yz+wqiH28FH1clvFzsER7gjtOJGWjq76rzD7AQAqcSM7TmtDT2
ccGVO/fQqJ4znu7cEDsupCI1Kx8eTvYY2aEBHm0ThC+2XMSkAc3h7miHgmI1HO0VyCkoRmGxGt4uJed09lYmhiwqGRnrF+GPvZdua+4+0mf64HCM7hICF6Vdhc+l1JkbGXj064rB7cmODdDM3w0Ph/vB0V6By6k5iLmZiSc6NEBiWi7aN/SCADB22REo7eQY2y0Ud+8VoF2wF3IKihHm56pz5NAUrCao3Lp1C/Xr18eBAwcQFRWl2f7+++9j9+7dOHz4sFb72bNnY86cORX6YVAhIiKyHtUJKuaJSmYybdo0ZGZmar4SExOlLomIiIjMSNKHEvr4+EChUCAlJUVre0pKCgICKj4GXqlUQqlUVthOREREtknSERUHBwd06NAB27dv12xTq9XYvn271qUgIiIiqpskHVEBgHfffRdjxoxBx44d0blzZyxYsAD37t3DuHHjpC6NiIiIJCZ5UHnqqadw+/ZtzJw5E8nJyWjbti02bdoEf39/qUsjIiIiiUm+jkpNcB0VIiIi62Ozd/0QERFR3cKgQkRERBaLQYWIiIgsFoMKERERWSwGFSIiIrJYDCpERERksRhUiIiIyGIxqBAREZHFknxl2pooXasuKytL4kqIiIjIUKV/tw1Zc9aqg0p2djYAIDg4WOJKiIiIqLqys7Ph4eFRaRurXkJfrVbj1q1bcHNzg0wmM2nfWVlZCA4ORmJiok0uz8/zs362fo48P+tm6+cH2P45mvP8hBDIzs5GUFAQ5PLKZ6FY9YiKXC5HgwYNzPoe7u7uNvkLWIrnZ/1s/Rx5ftbN1s8PsP1zNNf5VTWSUoqTaYmIiMhiMagQERGRxWJQ0UOpVGLWrFlQKpVSl2IWPD/rZ+vnyPOzbrZ+foDtn6OlnJ9VT6YlIiIi28YRFSIiIrJYDCpERERksRhUiIiIyGIxqBAREZHFYlDRYcmSJWjUqBEcHR3RpUsXHDlyROqSdJo9ezZkMpnWV3h4uGZ/fn4+JkyYgHr16sHV1RUjR45ESkqKVh/Xr1/HkCFD4OzsDD8/P0yePBnFxcVabXbt2oX27dtDqVQiLCwMy5cvN8v57NmzB8OGDUNQUBBkMhnWrl2rtV8IgZkzZyIwMBBOTk7o27cvLl26pNUmLS0No0ePhru7Ozw9PfHiiy8iJydHq82ZM2fw0EMPwdHREcHBwfj0008r1PLHH38gPDwcjo6OaNWqFTZs2GD28xs7dmyFz3PgwIFWc37R0dHo1KkT3Nzc4OfnhxEjRiAuLk6rTW3+Tpr637Eh59e7d+8Kn+Frr71mFecHAEuXLkXr1q01C3xFRUVh48aNmv3W/PkZcn7W/vmVN2/ePMhkMkycOFGzzSo/Q0FaVq5cKRwcHMT//vc/cfbsWfHyyy8LT09PkZKSInVpFcyaNUu0bNlSJCUlab5u376t2f/aa6+J4OBgsX37dnHs2DHRtWtX0a1bN83+4uJiERkZKfr27StOnjwpNmzYIHx8fMS0adM0ba5cuSKcnZ3Fu+++K86dOycWL14sFAqF2LRpk8nPZ8OGDeLf//63WL16tQAg1qxZo7V/3rx5wsPDQ6xdu1acPn1aPProoyI0NFTk5eVp2gwcOFC0adNGHDp0SOzdu1eEhYWJUaNGafZnZmYKf39/MXr0aBEbGyt+/fVX4eTkJL799ltNm/379wuFQiE+/fRTce7cOfHBBx8Ie3t7ERMTY9bzGzNmjBg4cKDW55mWlqbVxpLPb8CAAWLZsmUiNjZWnDp1SgwePFg0bNhQ5OTkaNrU1u+kOf4dG3J+vXr1Ei+//LLWZ5iZmWkV5yeEEH///bdYv369uHjxooiLixPTp08X9vb2IjY2Vghh3Z+fIedn7Z9fWUeOHBGNGjUSrVu3Fm+//bZmuzV+hgwq5XTu3FlMmDBB81qlUomgoCARHR0tYVW6zZo1S7Rp00bnvoyMDGFvby/++OMPzbbz588LAOLgwYNCiJI/nHK5XCQnJ2vaLF26VLi7
u4uCggIhhBDvv/++aNmypVbfTz31lBgwYICJz0Zb+T/karVaBAQEiM8++0yzLSMjQyiVSvHrr78KIYQ4d+6cACCOHj2qabNx40Yhk8nEzZs3hRBC/Oc//xFeXl6a8xNCiClTpojmzZtrXj/55JNiyJAhWvV06dJFvPrqq2Y7PyFKgsrw4cP1HmNN5yeEEKmpqQKA2L17txCidn8na+PfcfnzE6LkD13ZPwrlWdP5lfLy8hLff/+9zX1+5c9PCNv5/LKzs0XTpk3F1q1btc7JWj9DXvopo7CwEMePH0ffvn012+RyOfr27YuDBw9KWJl+ly5dQlBQEBo3bozRo0fj+vXrAIDjx4+jqKhI61zCw8PRsGFDzbkcPHgQrVq1gr+/v6bNgAEDkJWVhbNnz2ralO2jtE1t/zwSEhKQnJysVYuHhwe6dOmidT6enp7o2LGjpk3fvn0hl8tx+PBhTZuePXvCwcFB02bAgAGIi4tDenq6po1U57xr1y74+fmhefPmGD9+PO7evavZZ23nl5mZCQDw9vYGUHu/k7X177j8+ZX6+eef4ePjg8jISEybNg25ubmafdZ0fiqVCitXrsS9e/cQFRVlc59f+fMrZQuf34QJEzBkyJAKdVjrZ2jVDyU0tTt37kClUml9QADg7++PCxcuSFSVfl26dMHy5cvRvHlzJCUlYc6cOXjooYcQGxuL5ORkODg4wNPTU+sYf39/JCcnAwCSk5N1nmvpvsraZGVlIS8vD05OTmY6O22l9eiqpWytfn5+Wvvt7Ozg7e2t1SY0NLRCH6X7vLy89J5zaR/mMnDgQDz++OMIDQ3F5cuXMX36dAwaNAgHDx6EQqGwqvNTq9WYOHEiunfvjsjISM3718bvZHp6utn/Hes6PwB45plnEBISgqCgIJw5cwZTpkxBXFwcVq9ebTXnFxMTg6ioKOTn58PV1RVr1qxBREQETp06ZROfn77zA2zj81u5ciVOnDiBo0ePVthnrf8GGVSs2KBBgzTft27dGl26dEFISAh+//33WgsQZDpPP/205vtWrVqhdevWaNKkCXbt2oU+ffpIWFn1TZgwAbGxsdi3b5/UpZiFvvN75ZVXNN+3atUKgYGB6NOnDy5fvowmTZrUdplGad68OU6dOoXMzEz8+eefGDNmDHbv3i11WSaj7/wiIiKs/vNLTEzE22+/ja1bt8LR0VHqckyGl37K8PHxgUKhqDADOiUlBQEBARJVZThPT080a9YM8fHxCAgIQGFhITIyMrTalD2XgIAAnedauq+yNu7u7rUahkrrqeyzCQgIQGpqqtb+4uJipKWlmeSca/t3oHHjxvDx8UF8fLymLms4vzfeeAPr1q3Dzp070aBBA8322vqdNPe/Y33np0uXLl0AQOsztPTzc3BwQFhYGDp06IDo6Gi0adMGCxcutJnPT9/56WJtn9/x48eRmpqK9u3bw87ODnZ2dti9ezcWLVoEOzs7+Pv7W+VnyKBShoODAzp06IDt27drtqnVamzfvl3rGqalysnJweXLlxEYGIgOHTrA3t5e61zi4uJw/fp1zblERUUhJiZG64/f1q1b4e7urhkKjYqK0uqjtE1t/zxCQ0MREBCgVUtWVhYOHz6sdT4ZGRk4fvy4ps2OHTugVqs1/8GJiorCnj17UFRUpGmzdetWNG/eHF5eXpo2lnDON27cwN27dxEYGKipy5LPTwiBN954A2vWrMGOHTsqXIKqrd9Jc/07rur8dDl16hQAaH2Glnp++qjVahQUFFj951fV+elibZ9fnz59EBMTg1OnTmm+OnbsiNGjR2u+t8rPsNrTb23cypUrhVKpFMuXLxfnzp0Tr7zyivD09NSaAW0p3nvvPbFr1y6RkJAg9u/fL/r27St8fHxEamqqEKLkNrSGDRuKHTt2iGPHjomoqCgRFRWlOb70NrT+/fuLU6dOiU2bNglfX1+dt6FNnjxZnD9/XixZssRstydnZ2eLkydPipMnTwoA4ssv
vxQnT54U165dE0KU3J7s6ekp/vrrL3HmzBkxfPhwnbcnt2vXThw+fFjs27dPNG3aVOv23YyMDOHv7y+ee+45ERsbK1auXCmcnZ0r3L5rZ2cnPv/8c3H+/Hkxa9Ysk9y+W9n5ZWdni0mTJomDBw+KhIQEsW3bNtG+fXvRtGlTkZ+fbxXnN378eOHh4SF27dqldXtnbm6upk1t/U6a499xVecXHx8vPvzwQ3Hs2DGRkJAg/vrrL9G4cWPRs2dPqzg/IYSYOnWq2L17t0hISBBnzpwRU6dOFTKZTGzZskUIYd2fX1XnZwufny7l72Syxs+QQUWHxYsXi4YNGwoHBwfRuXNncejQIalL0umpp54SgYGBwsHBQdSvX1889dRTIj4+XrM/Ly9PvP7668LLy0s4OzuLxx57TCQlJWn1cfXqVTFo0CDh5OQkfHx8xHvvvSeKioq02uzcuVO0bdtWODg4iMaNG4tly5aZ5Xx27twpAFT4GjNmjBCi5BblGTNmCH9/f6FUKkWfPn1EXFycVh93794Vo0aNEq6ursLd3V2MGzdOZGdna7U5ffq06NGjh1AqlaJ+/fpi3rx5FWr5/fffRbNmzYSDg4No2bKlWL9+vVnPLzc3V/Tv31/4+voKe3t7ERISIl5++eUK/6gt+fx0nRsArd+X2vydNPW/46rO7/r166Jnz57C29tbKJVKERYWJiZPnqy1Docln58QQrzwwgsiJCREODg4CF9fX9GnTx9NSBHCuj+/qs7PFj4/XcoHFWv8DGVCCFH9cRgiIiIi8+McFSIiIrJYDCpERERksRhUiIiIyGIxqBAREZHFYlAhIiIii8WgQkRERBaLQYWIiIgsFoMKERERWSwGFSKqkUaNGmHBggUGt9+1axdkMlmFB6MREenClWmJ6pjevXujbdu21QoXlbl9+zZcXFzg7OxsUPvCwkKkpaXB398fMpnMJDVU165du/Dwww8jPT0dnp6ektRARIaxk7oAIrI8QgioVCrY2VX9nwhfX99q9e3g4FDjx9kTUd3BSz9EdcjYsWOxe/duLFy4EDKZDDKZDFevXtVcjtm4cSM6dOgApVKJffv24fLlyxg+fDj8/f3h6uqKTp06Ydu2bVp9lr/0I5PJ8P333+Oxxx6Ds7MzmjZtir///luzv/yln+XLl8PT0xObN29GixYt4OrqioEDByIpKUlzTHFxMd566y14enqiXr16mDJlCsaMGYMRI0boPddr165h2LBh8PLygouLC1q2bIkNGzbg6tWrePjhhwEAXl5ekMlkGDt2LICSR9FHR0cjNDQUTk5OaNOmDf78888Kta9fvx6tW7eGo6MjunbtitjYWCM/ESKqCoMKUR2ycOFCREVF4eWXX0ZSUhKSkpIQHBys2T916lTMmzcP58+fR+vWrZGTk4PBgwdj+/btOHnyJAYOHIhhw4bh+vXrlb7PnDlz8OSTT+LMmTMYPHgwRo8ejbS0NL3tc3Nz8fnnn+Onn37Cnj17cP36dUyaNEmzf/78+fj555+xbNky7N+/H1lZWVi7dm2lNUyYMAEFBQXYs2cPYmJiMH/+fLi6uiI4OBirVq0CAMTFxSEpKQkLFy4EAERHR+PHH3/EN998g7Nnz+Kdd97Bs88+i927d2v1PXnyZHzxxRc4evQofH19MWzYMBQVFVVaDxEZyahnLhOR1Sr/2HchSh7ZDkCsXbu2yuNbtmwpFi9erHkdEhIivvrqK81rAOKDDz7QvM7JyREAxMaNG7XeKz09XQghxLJlywQAER8frzlmyZIlwt/fX/Pa399ffPbZZ5rXxcXFomHDhmL48OF662zVqpWYPXu2zn3laxBCiPz8fOHs7CwOHDig1fbFF18Uo0aN0jpu5cqVmv13794VTk5O4rffftNbCxEZj3NUiEijY8eOWq9zcnIwe/ZsrF+/HklJSSguLkZeXl6VIyqtW7fWfO/i4gJ3d3ekpqbqbe/s7IwmTZpoXgcGBmraZ2ZmIiUl
BZ07d9bsVygU6NChA9Rqtd4+33rrLYwfPx5btmxB3759MXLkSK26youPj0dubi769euntb2wsBDt2rXT2hYVFaX53tvbG82bN8f58+f19k1ExmNQISINFxcXrdeTJk3C1q1b8fnnnyMsLAxOTk544oknUFhYWGk/9vb2Wq9lMlmloUJXe1HDGxJfeuklDBgwAOvXr8eWLVsQHR2NL774Am+++abO9jk5OQCA9evXo379+lr7lEpljWohIuNxjgpRHePg4ACVSmVQ2/3792Ps2LF47LHH0KpVKwQEBODq1avmLbAcDw8P+Pv74+jRo5ptKpUKJ06cqPLY4OBgvPbaa1i9ejXee+89fPfddwBKfgal/ZSKiIiAUqnE9evXERYWpvVVdh4PABw6dEjzfXp6Oi5evIgWLVrU6DyJSDeOqBDVMY0aNcLhw4dx9epVuLq6wtvbW2/bpk2bYvXq1Rg2bBhkMhlmzJhR6ciIubz55puIjo5GWFgYwsPDsXjxYqSnp1e6DsvEiRMxaNAgNGvWDOnp6di5c6cmTISEhEAmk2HdunUYPHgwnJyc4ObmhkmTJuGdd96BWq1Gjx49kJmZif3798Pd3R1jxozR9P3hhx+iXr168Pf3x7///W/4+PhUegcSERmPIypEdcykSZOgUCgQEREBX1/fSuebfPnll/Dy8kK3bt0wbNgwDBgwAO3bt6/FaktMmTIFo0aNwvPPP4+oqCi4urpiwIABcHR01HuMSqXChAkT0KJFCwwcOBDNmjXDf/7zHwBA/fr1MWfOHEydOhX+/v544403AABz587FjBkzEB0drTlu/fr1CA0N1ep73rx5ePvtt9GhQwckJyfjn3/+0YzSEJFpcWVaIrI6arUaLVq0wJNPPom5c+fW2vtyRVui2sdLP0Rk8a5du4YtW7agV69eKCgowNdff42EhAQ888wzUpdGRGbGSz9EZPHkcjmWL1+OTp06oXv37oiJicG2bds4gZWoDuClHyIiIrJYHFEhIiIii8WgQkRERBaLQYWIiIgsFoMKERERWSwGFSIiIrJYDCpERERksRhUiIiIyGIxqBAREZHF+n9rdycJaFPsyQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# Train a sequence-to-sequence machine-translation model.\n",
    "# train_data: list of (input_ids, target_ids) pairs\n",
    "# encoder, decoder: the two halves of the model, trained jointly\n",
    "# epochs, learning_rate: training hyperparameters\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\\\n",
    "        learning_rate=1e-3):\n",
    "    # Prepare the optimizers (one per sub-module) and the loss.\n",
    "    # NLLLoss expects log-probabilities, which the decoder outputs.\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "    encoder.zero_grad()\n",
    "    decoder.zero_grad()\n",
    "\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    # Bug fix: iterate over the `epochs` parameter instead of the\n",
    "    # module-level global `n_epochs`, so the argument takes effect.\n",
    "    with trange(epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for step, data in enumerate(train_data):\n",
    "                # Convert source/target sequences to 1 * seq_len tensors.\n",
    "                # For simplicity the batch size is 1; with larger batches\n",
    "                # the encoder would need padding and would have to return\n",
    "                # the hidden state of the last non-padding token, and\n",
    "                # decoding would need matching changes.\n",
    "                input_ids, target_ids = data\n",
    "                input_tensor, target_tensor = \\\n",
    "                    torch.tensor(input_ids).unsqueeze(0),\\\n",
    "                    torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # Feed the ground-truth target sequence for\n",
    "                # teacher-forcing training\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\\\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                pbar.set_description(f'epoch-{epoch}, '+\\\n",
    "                    f'loss={loss.item():.4f}')\n",
    "                step_losses.append(loss.item())\n",
    "                # With batch size 1 the per-step loss is very noisy;\n",
    "                # averaging the last 32 steps gives a smoother curve\n",
    "                # that is easier to inspect.\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "    \n",
    "hidden_size = 128\n",
    "n_epochs = 20\n",
    "learning_rate = 1e-3\n",
    "\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder, n_epochs, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be47e115",
   "metadata": {},
   "source": [
    "下面实现贪心搜索解码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "678192b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 弱 品 牌 ， 强 品 牌 ： 数 字 时 代 增 长 知 与 行\n",
      "target： weak brands , strong brands: information and action for growth in the digital age\n",
      "pred： weak brands , strong brands: information and action for growth in the digital age\n",
      "\n",
      "input： 细 胞 自 动 机 及 其 在 复 杂 网 络 中 的 应 用\n",
      "target： cellular self-motivation and its application in complex networks\n",
      "pred： cellular self-motivation and its application in complex networks\n",
      "\n",
      "input： 网 上 创 业 ： 商 业 模 式 + 操 作 实 战 + 案 例 分 析 （ 微 课 版 第 2 版 ）\n",
      "target： online entrepreneurship: business model + operational combat + case analysis (microcurricular version 2)\n",
      "pred： business etiquette (school editions)\n",
      "\n",
      "input： s p s s 统 计 分 析 从 入 门 到 精 通 ( 第 2 版 )\n",
      "target： spss statistical analysis from introduction to proficiency (version 2)\n",
      "pred： analysis from introduction to proficiency (version 2)\n",
      "\n",
      "input： e x c e l 机 器 学 习\n",
      "target： excel machine learning\n",
      "pred： excel machine learning deep learning based on opencv machine learning\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code adapted from the GitHub project pytorch/tutorials\n",
    "(Copyright (c) 2023, PyTorch, BSD-3-Clause License, see appendix)\n",
    "\"\"\"\n",
    "# Greedy decoding: at every step keep only the single most likely\n",
    "# token. Returns the decoded target sentence and the attention\n",
    "# weights produced by the decoder.\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sequence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "        \n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        decoder_outputs, decoder_hidden, decoder_attn = \\\n",
    "            decoder(encoder_outputs, encoder_hidden)\n",
    "        \n",
    "        # Take the highest-probability token at every step\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        \n",
    "        # Collect token ids up to the first <eos>, which is dropped\n",
    "        decoded_ids = []\n",
    "        for idx in topi.squeeze():\n",
    "            if idx.item() == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(idx.item())\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "            \n",
    "# Print a few random decoding examples for a qualitative check\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence, _ = greedy_decode(encoder, decoder, pair[0],\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38320f0",
   "metadata": {},
   "source": [
    "\n",
    "接下来使用束搜索解码来验证模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3496efa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 大 学 生 信 息 素 养 基 础\n",
      "target： the information base of university students\n",
      "pred： university students computer practice curriculum for university students students guidance for university students guidance the university computer practice curriculum for university students guidance the university computer practice curriculum for the\n",
      "\n",
      "input： 办 公 应 用 与 计 算 思 维 案 例 教 程 第 2 版\n",
      "target： office application and computational casework curriculum , 2nd ed .\n",
      "pred： basic curriculum for the application of government and computational casework curriculum 2nd edition .\n",
      "\n",
      "input： 中 文 版 3 d s m a x 2 0 1 2 基 础 培 训 教 程 第 2 版\n",
      "target： chinese version 3ds max 2012 basic training curriculum 2nd edition\n",
      "pred： chinese version 3ds max 2012 basic training curriculum (version 2)\n",
      "\n",
      "input： 母 婴 红 利 3 . 0 母 婴 店 精 细 化 管 理 实 战 手 册\n",
      "target： mother and child dividend 3 .0 manual on the precision management of mother and child shops\n",
      "pred： manual on the time management is team two .\n",
      "\n",
      "input： v r 全 景 拍 摄 一 本 通\n",
      "target： v .r . panorama takes a general picture .\n",
      "pred： the photographer's late class .\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Container class that manages the current set of beam-search\n",
    "# candidate hypotheses\n",
    "class BeamHypotheses:\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        # Maximum decoding length\n",
    "        self.max_length = max_length\n",
    "        # Beam size: maximum number of hypotheses kept\n",
    "        self.num_beams = num_beams\n",
    "        # Each entry is a tuple (score, token_ids, decoder_hidden)\n",
    "        self.beams = []\n",
    "        # Sentinel value; any real score is smaller than this\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "    \n",
    "    # Add one candidate hypothesis and update the worst score.\n",
    "    # The stored score is the length-normalized sum of log-probs.\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        if len(self) < self.num_beams or score > self.worst_score:\n",
    "            # Accept when the beam is not yet full, or when the new\n",
    "            # hypothesis beats the current worst one\n",
    "            self.beams.append((score, hyp, hidden))\n",
    "            if len(self) > self.num_beams:\n",
    "                # Beam is over-full: drop the worst hypothesis\n",
    "                sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                    (s, _, _) in enumerate(self.beams)])\n",
    "                del self.beams[sorted_scores[0][1]]\n",
    "                self.worst_score = sorted_scores[1][0]\n",
    "            else:\n",
    "                self.worst_score = min(score, self.worst_score)\n",
    "    \n",
    "    # Pop one unfinished hypothesis. The first return value tells\n",
    "    # whether one was found; if so, the second return value is the\n",
    "    # (score, hyp, hidden) tuple.\n",
    "    def pop(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        for i, (s, hyp, hid) in enumerate(self.beams):\n",
    "            # An unfinished hypothesis is shorter than the maximum\n",
    "            # decoding length and does not end with <eos>\n",
    "            if len(hyp) < self.max_length and (len(hyp) == 0\\\n",
    "                    or hyp[-1] != EOS_token):\n",
    "                del self.beams[i]\n",
    "                if len(self) > 0:\n",
    "                    # Recompute the worst score after the removal\n",
    "                    sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                        (s, _, _) in enumerate(self.beams)])\n",
    "                    self.worst_score = sorted_scores[0][0]\n",
    "                else:\n",
    "                    self.worst_score = 1e9\n",
    "                return True, (s, hyp, hid)\n",
    "        return False, None\n",
    "    \n",
    "    # Pop the highest-scoring hypothesis. The first return value\n",
    "    # tells whether one was found; if so, the second return value\n",
    "    # is the (score, hyp, hidden) tuple.\n",
    "    def pop_best(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        sorted_scores = sorted([(s, idx) for idx, (s, _, _)\\\n",
    "            in enumerate(self.beams)])\n",
    "        return True, self.beams[sorted_scores[-1][1]]\n",
    "\n",
    "\n",
    "# Beam-search decoding, keeping `num_beams` hypotheses at a time\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sequence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # Seed the container with a single empty hypothesis\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        init_hyp = []\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, init_hyp, encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # Take one unfinished hypothesis per iteration\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "                \n",
    "            score, hyp, decoder_hidden = item\n",
    "            \n",
    "            # Current decoder input: the last generated token, or\n",
    "            # <sos> for the empty hypothesis\n",
    "            if len(hyp) > 0:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(hyp[-1])\n",
    "            else:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(SOS_token)\n",
    "\n",
    "            # Decode a single step\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "\n",
    "            # Take the top-k tokens from the output distribution\n",
    "            topk_values, topk_ids = decoder_output.topk(num_beams)\n",
    "            # Extend the hypothesis with each candidate token and add\n",
    "            # the results back to the container. Multiplying score by\n",
    "            # len(hyp) undoes the length normalization applied in\n",
    "            # add(), recovering the sum of log-probabilities.\n",
    "            for logp, token_id in zip(topk_values.squeeze(),\\\n",
    "                    topk_ids.squeeze()):\n",
    "                sum_logprobs = score * len(hyp) + logp.item()\n",
    "                new_hyp = hyp + [token_id.item()]\n",
    "                hypotheses.add(sum_logprobs, new_hyp, decoder_hidden)\n",
    "\n",
    "        # Return the best finished hypothesis, stripping a trailing <eos>\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if flag:\n",
    "            hyp = item[1]\n",
    "            if hyp[-1] == EOS_token:\n",
    "                del hyp[-1]\n",
    "            return output_lang.ids2sent(hyp)\n",
    "        else:\n",
    "            return ''\n",
    "\n",
    "# Print a few random decoding examples for a qualitative check\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence = beam_search_decode(encoder, decoder,\\\n",
    "        pair[0], input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
